{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST Training and Prediction with SageMaker Chainer\n",
    "\n",
    "[MNIST](http://yann.lecun.com/exdb/mnist/), the \"Hello World\" of machine learning, is a popular dataset for handwritten digit classification. It consists of 70,000 28x28 grayscale images labeled in 10 digit classes (0 to 9). This tutorial will show how to train a model to predict handwritten digits on the MNIST dataset by running a Chainer script on SageMaker using the sagemaker-python-sdk.\n",
    "\n",
    "For more information about the Chainer container, see the sagemaker-chainer-containers repository and the sagemaker-python-sdk repository:\n",
    "\n",
    "* https://github.com/aws/sagemaker-chainer-containers\n",
    "* https://github.com/aws/sagemaker-python-sdk\n",
    "\n",
    "For more on Chainer, please visit the Chainer repository:\n",
    "\n",
    "* https://github.com/chainer/chainer\n",
    "\n",
    "This notebook is adapted from the [MNIST](https://github.com/chainer/chainer/tree/master/examples/mnist) example in the Chainer repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "sagemaker_session = sagemaker.Session()\n",
    "\n",
    "# Get a SageMaker-compatible role used by this Notebook Instance.\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to use the SageMaker Python SDK to run your code in a local container before deploying to SageMaker's managed training or hosting environments. Just change your estimator's train_instance_type to `local` or `local_gpu`. For more information, see [local mode](https://github.com/aws/sagemaker-python-sdk#local-mode).\n",
    "\n",
    "In order to use this feature you'll need to install docker-compose (and nvidia-docker if training with a GPU). Running following script will install docker-compose or nvidia-docker-compose and configure the notebook environment for you.\n",
    "\n",
    "Note, you can only run a single local notebook at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!/bin/bash ./setup.sh"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download MNIST datasets\n",
    "\n",
    "We can use Chainer's built-in `get_mnist()` method to download, import and preprocess the MNIST dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import chainer\n",
    "\n",
    "train, test = chainer.datasets.get_mnist()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parse, save, and upload the data\n",
    "\n",
    "We save our data, then use `sagemaker_session.upload_data` to upload the data to an S3 location used for training. The return value identifies the S3 path to the uploaded data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import numpy as np\n",
    "\n",
    "train_images = np.array([data[0] for data in train])\n",
    "train_labels = np.array([data[1] for data in train])\n",
    "test_images = np.array([data[0] for data in test])\n",
    "test_labels = np.array([data[1] for data in test])\n",
    "\n",
    "try:\n",
    "    os.makedirs('/tmp/data/train')\n",
    "    os.makedirs('/tmp/data/test')\n",
    "\n",
    "    np.savez('/tmp/data/train/train.npz', images=train_images, labels=train_labels)\n",
    "    np.savez('/tmp/data/test/test.npz', images=test_images, labels=test_labels)\n",
    "\n",
    "    train_input = sagemaker_session.upload_data(\n",
    "        path=os.path.join('/tmp/data', 'train'),\n",
    "        key_prefix='notebook/chainer/mnist')\n",
    "    test_input = sagemaker_session.upload_data(\n",
    "        path=os.path.join('/tmp/data', 'test'),\n",
    "        key_prefix='notebook/chainer/mnist')\n",
    "finally:\n",
    "    shutil.rmtree('/tmp/data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Writing the Chainer script to run on Amazon SageMaker\n",
    "\n",
    "### Training\n",
    "\n",
    "We need to provide a training script that can run on the SageMaker platform. The training script is very similar to a training script you might run outside of SageMaker, but you can access useful properties about the training environment through various environment variables, such as:\n",
    "\n",
    "* `SM_MODEL_DIR`: A string representing the path to the directory to write model artifacts to.\n",
    "  These artifacts are uploaded to S3 for model hosting.\n",
    "* `SM_NUM_GPUS`: An integer representing the number of GPUs available to the host.\n",
    "* `SM_OUTPUT_DIR`: A string representing the filesystem path to write output artifacts to. Output artifacts may\n",
    "  include checkpoints, graphs, and other files to save, not including model artifacts. These artifacts are compressed\n",
    "  and uploaded to S3 to the same S3 prefix as the model artifacts.\n",
    "\n",
    "Supposing two input channels, 'train' and 'test', were used in the call to the Chainer estimator's ``fit()`` method,\n",
    "the following will be set, following the format `SM_CHANNEL_[channel_name]`:\n",
    "\n",
    "* `SM_CHANNEL_TRAIN`: A string representing the path to the directory containing data in the 'train' channel\n",
    "* `SM_CHANNEL_TEST`: Same as above, but for the 'test' channel.\n",
    "\n",
    "A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves a model to `model_dir` so that it can be hosted later. Hyperparameters are passed to your script as arguments and can be retrieved with an `argparse.ArgumentParser` instance. For example, the script run by this notebook starts with the following:\n",
    "\n",
    "```python\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "if __name__=='__main__':\n",
    "\n",
    "    parser = argparse.ArgumentParser()\n",
    "\n",
    "    # hyperparameters sent by the client are passed as command-line arguments to the script.\n",
    "    parser.add_argument('--epochs', type=int, default=50)\n",
    "    parser.add_argument('--batch-size', type=int, default=64)\n",
    "\n",
    "    # Data and model checkpoints directories\n",
    "    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])\n",
    "    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])\n",
    "    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])\n",
    "    parser.add_argument('--test', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])\n",
    "    \n",
    "    args, _ = parser.parse_known_args()\n",
    "    \n",
    "    num_gpus = int(os.environ['SM_NUM_GPUS'])\n",
    "    \n",
    "    # ... load from args.train and args.test, train a model, write model to args.model_dir.\n",
    "```\n",
    "\n",
    "Because the Chainer container imports your training script, you should always put your training code in a main guard (`if __name__=='__main__':`) so that the container does not inadvertently run your training code at the wrong point in execution.\n",
    "\n",
    "For more information about training environment variables, please visit https://github.com/aws/sagemaker-containers.\n",
    "\n",
    "### Hosting and Inference\n",
    "\n",
    "We use a single script to train and host the Chainer model. You can also write separate scripts for training and hosting. In contrast with the training script, the hosting script requires you to implement functions with particular function signatures (or rely on defaults for those functions).\n",
    "\n",
    "These functions load your model, deserialize data sent by a client, obtain inferences from your hosted model, and serialize predictions back to a client:\n",
    "\n",
    "* **`model_fn(model_dir)` (always required for hosting)**: This function is invoked to load model artifacts from those written into `model_dir` during training.\n",
    "\n",
    "This is the `model_fn` used in the script below during hosting:\n",
    "\n",
    "```python\n",
    "def model_fn(model_dir):\n",
    "    model = L.Classifier(MLP(1000, 10))\n",
    "    serializers.load_npz(os.path.join(model_dir, 'model.npz'), model)\n",
    "    return model.predictor\n",
    "```\n",
    "\n",
    "* `input_fn(input_data, content_type)`: This function is invoked to deserialize prediction data when a prediction request is made. The return value is passed to predict_fn. `input_data` is the serialized input data in the body of the prediction request, and `content_type`, the MIME type of the data.\n",
    "  \n",
    "  \n",
    "* `predict_fn(input_data, model)`: This function accepts the return value of `input_fn` as the `input_data` parameter and the return value of `model_fn` as the `model` parameter and returns inferences obtained from the model.\n",
    "  \n",
    "  \n",
    "* `output_fn(prediction, accept)`: This function is invoked to serialize the return value from `predict_fn`, which is passed in as the `prediction` parameter, back to the SageMaker client in response to prediction requests.\n",
    "\n",
    "\n",
    "`model_fn` is always required, but default implementations exist for the remaining functions. These default implementations can deserialize a NumPy array, invoking the model's `__call__` method on the input data, and serialize a NumPy array back to the client.\n",
    "\n",
    "This notebook relies on the default `input_fn`, `predict_fn`, and `output_fn` implementations. See the Chainer sentiment analysis notebook for an example of how one can implement these hosting functions.\n",
    "\n",
    "Please examine the script below. Training occurs behind the main guard, which prevents the function from being run when the script is imported, and `model_fn` loads the model saved into `model_dir` during training.\n",
    "\n",
    "For more on writing Chainer scripts to run on SageMaker, or for more on the Chainer container itself, please see the following repositories: \n",
    "\n",
    "* For writing Chainer scripts to run on SageMaker: https://github.com/aws/sagemaker-python-sdk\n",
    "* For more on the Chainer container and default hosting functions: https://github.com/aws/sagemaker-chainer-containers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!pygmentize 'chainer_mnist_single_machine.py'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create SageMaker chainer estimator\n",
    "\n",
    "To run our Chainer training script on SageMaker, we construct a `sagemaker.chainer.estimator.Chainer` estimator, which accepts several constructor arguments:\n",
    "\n",
    "* `entry_point`: The path to the Python script SageMaker runs for training and prediction.\n",
    "\n",
    "\n",
    "* `train_instance_count`: An integer representing how many training instances to start.\n",
    "\n",
    "\n",
    "* `train_instance_type`: The type of SageMaker instances for training. We pass the string `local` or `local_gpu` here to enable the local mode for training in the local environment. `local` is for cpu training and `local_gpu` is for gpu training. If you want to train on a remote instance, specify a SageMaker ML instance type here accordingly. See [Amazon SageMaker ML Instance Types](https://aws.amazon.com/sagemaker/pricing/instance-types/) for a list of instance types.\n",
    "\n",
    "\n",
    "* `hyperparameters`: A dictionary passed to the `train` function as `hyperparameters`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "from sagemaker.chainer.estimator import Chainer\n",
    "\n",
    "instance_type = 'local'\n",
    "\n",
    "if subprocess.call('nvidia-smi') == 0:\n",
    "    ## Set type to GPU if one is present\n",
    "    instance_type = 'local_gpu'\n",
    "    \n",
    "print(\"Instance type = \" + instance_type)\n",
    "\n",
    "chainer_estimator = Chainer(entry_point='chainer_mnist_single_machine.py', role=role,\n",
    "                            train_instance_count=1, train_instance_type=instance_type,\n",
    "                            hyperparameters={'epochs': 3, 'batch_size': 128})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train on MNIST data in S3\n",
    "\n",
    "After we've constructed our Chainer object, we can fit it using the MNIST data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem, so our user script can simply read the data from disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "chainer_estimator.fit({'train': train_input, 'test': test_input})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our user script writes various artifacts, such as plots, to a directory `output_data_dir`, the contents of which SageMaker uploads to S3. Now we download and extract these artifacts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    os.makedirs('output/single_machine_mnist')\n",
    "except OSError:\n",
    "    pass\n",
    "\n",
    "chainer_training_job = chainer_estimator.latest_training_job.name\n",
    "\n",
    "desc = chainer_estimator.sagemaker_session.sagemaker_client. \\\n",
    "           describe_training_job(TrainingJobName=chainer_training_job)\n",
    "output_data = desc['ModelArtifacts']['S3ModelArtifacts'].replace('model', 'output')\n",
    "!aws --region {sagemaker_session.boto_session.region_name} s3 cp {output_data} output/single_machine_mnist/output.tar\n",
    "!tar -xzvf output/single_machine_mnist/output.tar -C output/single_machine_mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These plots show the accuracy and loss over each epoch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "from IPython.display import display\n",
    "\n",
    "accuracy_graph = Image(filename=\"output/single_machine_mnist/accuracy.png\",\n",
    "                       width=800,\n",
    "                       height=800)\n",
    "loss_graph = Image(filename=\"output/single_machine_mnist/loss.png\",\n",
    "                   width=800,\n",
    "                   height=800)\n",
    "\n",
    "display(accuracy_graph, loss_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deploy model to endpoint\n",
    "\n",
    "After training, we deploy the model to an endpoint. Here we also specify instance_type to be `local` or `local_gpu` to deploy the model to the local environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "predictor = chainer_estimator.deploy(initial_instance_count=1, instance_type=instance_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predict Hand-Written Digit\n",
    "\n",
    "We can use this predictor returned by `deploy` to send inference requests to our locally-hosted model. Let's get some random test images in MNIST first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "num_samples = 5\n",
    "indices = random.sample(range(test_images.shape[0] - 1), num_samples)\n",
    "images, labels = test_images[indices], test_labels[indices]\n",
    "\n",
    "for i in range(num_samples):\n",
    "    plt.subplot(1,num_samples,i+1)\n",
    "    plt.imshow(images[i].reshape(28, 28), cmap='gray')\n",
    "    plt.title(labels[i])\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's see if we can make correct predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "prediction = predictor.predict(images)\n",
    "predicted_label = prediction.argmax(axis=1)\n",
    "print('The predicted labels are: {}'.format(predicted_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's get some test data from you! Drawing into the image box loads the pixel data into a variable named 'data' in this notebook, which we can then pass to the Chainer predictor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "HTML(open(\"input.html\").read())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's see if your writing can be recognized!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "image = np.array(data, dtype=np.float32)\n",
    "prediction = predictor.predict(image)\n",
    "predicted_label = prediction.argmax(axis=1)[0]\n",
    "print('What you wrote is: {}'.format(predicted_label))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Clean resources\n",
    "\n",
    "After you have finished with this example, remember to delete the prediction endpoint to release the instance associated with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "chainer_estimator.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_chainer_p36",
   "language": "python",
   "name": "conda_chainer_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
